/*
*  This file is part of ygg-brute
*  Copyright (c) 2020 ygg-brute authors
*  See LICENSE for licensing information
*/

#pragma once

#include <cstdint>

#include "generic/sha512_table_def.h"
#include "cuda/common.cuh"

namespace cuda {

static __constant__ uint64_t sha512_H[] = SHA512_H_DEF;
static __constant__ uint64_t sha512_K[] = SHA512_K_DEF;

// __device__ __forceinline__ uint64_t s0(uint64_t x)
// {
//     uint64_t r;
//     asm(R"({
//         .reg.u32 xl, xh, rh, rl, t;
//         mov.b64 {xl, xh}, %1;
//         shf.r.wrap.b32 rl, xl, xh, 1;
//         shf.r.wrap.b32 rh, xh, xl, 1;

//         shf.r.wrap.b32 t, xl, xh, 8;
//         xor.b32 rl, rl, t;
//         shf.r.wrap.b32 t, xh, xl, 8;
//         xor.b32 rh, rh, t;

//         shf.r.wrap.b32 t, xl, xh, 7;
//         xor.b32 rl, rl, t;
//         shr.u32 t, xh, 7;
//         xor.b32 rh, rh, t;

//         mov.b64 %0, {rl, rh};
//     })"
//     : "=l"(r) // %0
//     : "l"(x) // %1
//     );

//     return r;
// }

// __device__ __forceinline__ uint64_t s2(uint64_t x)
// {
//     uint64_t r;
//     asm(R"({
//         .reg.u32 xl, xh, rh, rl, th, tl, t;
//         mov.b64 {xl, xh}, %1;
//         shf.r.wrap.b32 tl, xl, xh, 28;
//         shf.r.wrap.b32 th, xh, xl, 28;

//         mov.b32 rl, tl;
//         mov.b32 rh, th;

//         shf.r.wrap.b32 t, tl, th, 6;
//         xor.b32 rl, rl, t;
//         shf.r.wrap.b32 t, th, tl, 6;
//         xor.b32 rh, rh, t;

//         shf.r.wrap.b32 t, tl, th, 11;
//         xor.b32 rl, rl, t;
//         shf.r.wrap.b32 t, th, tl, 11;
//         xor.b32 rh, rh, t;

//         mov.b64 %0, {rl, rh};
//     })"
//     : "=l"(r) // %0
//     : "l"(x) // %1
//     );

//     return r;
// }

// __device__ __forceinline__ uint64_t s3(uint64_t x)
// {
//     uint64_t r;
//     asm(R"({
//         .reg.u32 xl, xh, rh, rl, th, tl, t;
//         mov.b64 {xl, xh}, %1;
//         shf.r.wrap.b32 tl, xl, xh, 14;
//         shf.r.wrap.b32 th, xh, xl, 14;

//         mov.b32 rl, tl;
//         mov.b32 rh, th;

//         shf.r.wrap.b32 t, tl, th, 4;
//         xor.b32 rl, rl, t;
//         shf.r.wrap.b32 t, th, tl, 4;
//         xor.b32 rh, rh, t;

//         shf.r.wrap.b32 t, tl, th, 27;
//         xor.b32 rl, rl, t;
//         shf.r.wrap.b32 t, th, tl, 27;
//         xor.b32 rh, rh, t;

//         mov.b64 %0, {rl, rh};
//     })"
//     : "=l"(r) // %0
//     : "l"(x) // %1
//     );

//     return r;
// }

// __device__ __forceinline__ uint64_t maj(uint64_t x, uint64_t y, uint64_t z)
// {
//     uint64_t r;

//     asm(R"({
//         .reg.u32 xhi, xlo, yhi, ylo, zhi, zlo, rhi, rlo;
//         mov.b64 {xhi, xlo}, %1;
//         mov.b64 {yhi, ylo}, %2;
//         mov.b64 {zhi, zlo}, %3;

//         lop3.b32 rhi, xhi, yhi, zhi, 0xe8;
//         lop3.b32 rlo, xlo, ylo, zlo, 0xe8;

//         mov.b64 %0, {rhi, rlo};
//     })"
//     : "=l"(r)
//     : "l"(x), "l"(y), "l"(z));

//     return r;
// }

// __device__ __forceinline__ uint64_t ch(uint64_t x, uint64_t y, uint64_t z)
// {
//     uint64_t r;

//     asm(R"({
//         .reg.u32 xhi, xlo, yhi, ylo, zhi, zlo, rhi, rlo;
//         mov.b64 {xhi, xlo}, %1;
//         mov.b64 {yhi, ylo}, %2;
//         mov.b64 {zhi, zlo}, %3;

//         lop3.b32 rhi, xhi, yhi, zhi, 0xca;
//         lop3.b32 rlo, xlo, ylo, zlo, 0xca;

//         mov.b64 %0, {rhi, rlo};
//     })"
//     : "=l"(r)
//     : "l"(x), "l"(y), "l"(z));

//     return r;
// }

__device__ __forceinline__ void put_uint64(uint64_t value, uint32_t* dst)
{
    asm(R"({
        mov.b64 {%1, %0}, %2;
    })"
    : "=r"(dst[0]), "=r"(dst[1])
    : "l"(value)
    );
}

__device__ __forceinline__ void sha512_of_32_byte_block(const uint32_t data[8], uint32_t hash[8])
{
    constexpr uint64_t W5 = 1ull << 63; // '1'
    constexpr uint64_t W15 = 0x0100; // block length (128)

    int i;
    uint64_t temp1, temp2, W[80];
    uint64_t A, B, C, D, E, F, G, H;

#   define SHR(x, n) (x >> n)
#   define ROTR(x, n) (SHR(x, n) | (x << (64 - n)))

#   define S0(x) (ROTR(x, 1) ^ ROTR(x, 8) ^ SHR(x, 7))
// #   define S0(x) s0(x)
#   define S1(x) (ROTR(x, 19) ^ ROTR(x, 61) ^ SHR(x, 6))

#   define S2(x) (ROTR(x, 28) ^ ROTR(x, 34) ^ ROTR(x, 39))
// #   define S2(x) s2(x)
#   define S3(x) (ROTR(x, 14) ^ ROTR(x, 18) ^ ROTR(x, 41))
// #   define S3(x) s3(x)

#   define F0(x, y, z) ((x & y) | (z & (x | y)))
#   define F1(x, y, z) (z ^ (x & (y ^ z)))
// #   define F0 maj
// #   define F1 ch

#   define P(a, b, c, d, e, f, g, h, x, K)      \
    {                                          \
        temp1 = h + S3(e) + F1(e, f, g) + K + x; \
        temp2 = S2(a) + F0(a, b, c);             \
        d += temp1;                              \
        h = temp1 + temp2;                       \
    }
#   pragma unroll
    for (i = 0; i < 4; i++) {
        asm(R"({
            mov.b64 %0, {%2, %1};
        })"
        : "=l"(W[i])
        : "r"(bswap_u32(data[2 * i])), "r"(bswap_u32(data[2 * i + 1]))
        );
    }
    W[i++] = W5;
    #pragma unroll
    for(; i < 15; ++i) {
        W[i] = 0;
    }
    W[i++] = W15;

#   pragma unroll 64
    for (; i < 80; i++) {
        W[i] = S1(W[i - 2]) + W[i - 7] + S0(W[i - 15]) + W[i - 16];
    }

    A = sha512_H[0];
    B = sha512_H[1];
    C = sha512_H[2];
    D = sha512_H[3];
    E = sha512_H[4];
    F = sha512_H[5];
    G = sha512_H[6];
    H = sha512_H[7];
    i = 0;

#   pragma unroll 10
    do {
        P(A, B, C, D, E, F, G, H, W[i], sha512_K[i]);
        i++;
        P(H, A, B, C, D, E, F, G, W[i], sha512_K[i]);
        i++;
        P(G, H, A, B, C, D, E, F, W[i], sha512_K[i]);
        i++;
        P(F, G, H, A, B, C, D, E, W[i], sha512_K[i]);
        i++;
        P(E, F, G, H, A, B, C, D, W[i], sha512_K[i]);
        i++;
        P(D, E, F, G, H, A, B, C, W[i], sha512_K[i]);
        i++;
        P(C, D, E, F, G, H, A, B, W[i], sha512_K[i]);
        i++;
        P(B, C, D, E, F, G, H, A, W[i], sha512_K[i]);
        i++;
    } while (i < 80);

    put_uint64(A + (sha512_H)[0], hash);
    put_uint64(B + (sha512_H)[1], hash + 2);
    put_uint64(C + (sha512_H)[2], hash + 4);
    put_uint64(D + (sha512_H)[3], hash + 6);
    // put_uint64(E + (sha512_H)[4], hash + 8);
    // put_uint64(F + (sha512_H)[5], hash + 10);
    // put_uint64(G + (sha512_H)[6], hash + 12);
    // put_uint64(H + (sha512_H)[7], hash + 14);

#   undef SHR
#   undef ROTR
#   undef S0
#   undef S1
#   undef S2
#   undef S3
#   undef F0
#   undef F1
#   undef P
}

}